{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dXPeGxzWPA73"
      },
      "source": [
        "# Demo for annotate a point track with optical flow\n",
        "\n",
        "This notebook illustrates how we use optical flow to facilitate human annotation on point tracking. Note that it is very hard to annotate a point track extensively along a whole video sequence. However we find dense optical flow estimation these days are fast and accurate. In this demo, we utilize [RAFT](https://pytorch.org/vision/stable/auto_examples/plot_optical_flow.html) to compute the dense optical flow for us.\n",
        "\n",
        "We then ask the annotater to select a point in the starting frame and the corresponding point location in the ending frame. A dynamic programming algorithm is used to optimize the estimated tracks given starting and ending point location. Note that the algorithm here differs from what we use in the original annotation system (dijkstra algorithm).\n",
        "\n",
        "The dynamic programming algorithm here requires large matrix computation. Hence running on GPU will be a lot faster."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Wv9x5NjJzm54"
      },
      "outputs": [],
      "source": [
        "!pip install mediapy mako flow_vis"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZxsjTxUqzrwV"
      },
      "outputs": [],
      "source": [
        "# @title Imports {form-width: \"25%\"}\n",
        "\n",
        "import copy\n",
        "import io\n",
        "import flow_vis\n",
        "import functools\n",
        "import gc\n",
        "import IPython\n",
        "import mediapy as media\n",
        "import numpy as np\n",
        "from PIL import Image\n",
        "from google.colab import html\n",
        "import base64\n",
        "from mako import template\n",
        "import torch\n",
        "import torchvision\n",
        "from tqdm import tqdm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p3FnMfsszsMb"
      },
      "outputs": [],
      "source": [
        "# If you can, run this example on a GPU, it will be a lot faster.\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "torch.set_grad_enabled(False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5eKOeizEztfG"
      },
      "outputs": [],
      "source": [
        "# @title Load an Exemplar Video {form-width: \"25%\"}\n",
        "\n",
        "!wget -P tapnet/examplar_videos https://storage.googleapis.com/dm-tapnet/horsejump-high.mp4\n",
        "\n",
        "video = media.read_video('tapnet/examplar_videos/horsejump-high.mp4')\n",
        "video = media.resize_video(video, (480, 768))\n",
        "height, width = video.shape[1:3]\n",
        "media.show_video(video, fps=10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "m_LydwEBzu15"
      },
      "outputs": [],
      "source": [
        "# @title Predict Optical Flows with RAFT {form-width: \"25%\"}\n",
        "\n",
        "from torchvision.models.optical_flow import raft_large\n",
        "from torchvision.models.optical_flow import Raft_Large_Weights\n",
        "\n",
        "model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)\n",
        "model = model.eval()\n",
        "\n",
        "optical_flows = []\n",
        "for i in tqdm(range(video.shape[0] - 1)):\n",
        "  image1 = video[i].astype(np.float32) / 127.5 - 1.0\n",
        "  image1 = image1.transpose(2, 0, 1)[None]\n",
        "  image2 = video[i + 1].astype(np.float32) / 127.5 - 1.0\n",
        "  image2 = image2.transpose(2, 0, 1)[None]\n",
        "  flow = model(torch.tensor(image1).to(device), torch.tensor(image2).to(device))\n",
        "  flow = flow[-1][0].cpu().numpy()\n",
        "  flow = flow.transpose(1, 2, 0)\n",
        "  optical_flows.append(flow)\n",
        "optical_flows = np.stack(optical_flows)\n",
        "\n",
        "# Release Memory after Prediction\n",
        "del model\n",
        "gc.collect()\n",
        "torch.cuda.empty_cache()\n",
        "\n",
        "print(optical_flows.shape)\n",
        "print(np.abs(optical_flows).max())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_tzMX-Yizwy9"
      },
      "outputs": [],
      "source": [
        "# @title Visualize Optical Flows {form-width: \"25%\"}\n",
        "\n",
        "flow_viz = []\n",
        "for i in range(optical_flows.shape[0]):\n",
        "  flow_viz.append(flow_vis.flow_to_color(optical_flows[i]))\n",
        "flow_viz = np.stack(flow_viz)\n",
        "\n",
        "media.show_video(flow_viz, fps=10)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5u_YkBmrzyOL"
      },
      "outputs": [],
      "source": [
        "# @title HTML Template {form-width: \"25%\"}\n",
        "\n",
        "class Img(html.Element):\n",
        "  def __init__(self, src=None, show=False):\n",
        "    super(Img, self).__init__('img')\n",
        "    if src is not None:\n",
        "      self.src = src\n",
        "    self.set_attribute('style', ('display:block;' if show else 'display:none;')+'margin:0px;')\n",
        "\n",
        "  @property\n",
        "  def src(self):\n",
        "    return self.get_property('src')\n",
        "\n",
        "  @src.setter\n",
        "  def src(self, value):\n",
        "    content = self._to_jpeg(value)\n",
        "    url = 'data:image/jpeg;base64,' + base64.b64encode(content).decode('utf-8')\n",
        "    self.set_property('src', url)\n",
        "\n",
        "  def _to_jpeg(self, np_image):\n",
        "    img = Image.fromarray(np_image)\n",
        "    buf = io.BytesIO()\n",
        "    img.save(buf, format=\"JPEG\")\n",
        "    return buf.getvalue()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mIbQG5Ycz5xu"
      },
      "outputs": [],
      "source": [
        "# @title Dynamic Programming Algorithm {form-width: \"25%\"}\n",
        "\n",
        "def interpolate(flows, frame1, click1, frame2, click2, radius=20):\n",
        "  x1, y1 = click2idx(click1)\n",
        "  x2, y2 = click2idx(click2)\n",
        "\n",
        "  window = 2 * radius + 1\n",
        "  x, y = np.meshgrid(np.arange(-radius, radius + 1), np.arange(-radius, radius + 1))\n",
        "  offset_cost = np.stack([x, y], axis=-1)\n",
        "  offset_cost = torch.tensor(offset_cost).to(device)\n",
        "\n",
        "  num_frames, height, width = flows.shape[0:3]\n",
        "\n",
        "  forward_i = np.zeros((num_frames + 1, height, width), dtype=np.int32)\n",
        "  forward_j = np.zeros((num_frames + 1, height, width), dtype=np.int32)\n",
        "\n",
        "  forward_cost = torch.ones((height, width)).to(device) * 1e10\n",
        "  forward_cost[y1, x1] = 0\n",
        "\n",
        "  for t in range(frame1, frame2):\n",
        "    cost_pad = torch.nn.functional.pad(forward_cost, (radius, radius, radius, radius), 'constant', value=1e10)\n",
        "    cost_unfold = cost_pad.unfold(0, window, 1).unfold(1, window, 1)\n",
        "    del cost_pad\n",
        "    gc.collect()\n",
        "    torch.cuda.empty_cache()\n",
        "\n",
        "    flow_cuda = torch.tensor(flows[t]).to(device)\n",
        "    flow_pad = torch.nn.functional.pad(flow_cuda, (0, 0, radius, radius, radius, radius), 'constant', value=1e10)\n",
        "    flow_unfold = flow_pad.unfold(0, window, 1).unfold(1, window, 1).permute(0, 1, 3, 4, 2)\n",
        "    del flow_cuda, flow_pad\n",
        "    gc.collect()\n",
        "    torch.cuda.empty_cache()\n",
        "\n",
        "    cost = cost_unfold + torch.abs(-offset_cost[None, None] - flow_unfold).sum(axis=-1)\n",
        "    cost = cost.reshape(height, width, -1)\n",
        "    forward_cost, argmin_indices = torch.min(cost, axis=-1)\n",
        "    del cost\n",
        "    gc.collect()\n",
        "    torch.cuda.empty_cache()\n",
        "\n",
        "    argmin_indices = argmin_indices.cpu().numpy()\n",
        "    forward_i_min, forward_j_min = argmin_indices // (window), argmin_indices % (window)\n",
        "    forward_i[t] = forward_i_min + np.arange(height)[:, None] - radius\n",
        "    forward_j[t] = forward_j_min + np.arange(width)[None] - radius\n",
        "\n",
        "  last_cost = torch.ones((height, width)).to(device) * 1e10\n",
        "  last_cost[y2, x2] = 0\n",
        "  forward_cost += last_cost\n",
        "  min_cost = torch.min(forward_cost).cpu().numpy()\n",
        "\n",
        "  argmin_indices = torch.argmin(forward_cost).item()\n",
        "  min_i, min_j = argmin_indices // width, argmin_indices % width\n",
        "  min_ij = [(min_j, min_i)]\n",
        "\n",
        "  for t in range(frame2 - 1, frame1 - 1, -1):\n",
        "    min_i, min_j = forward_i[t, min_i, min_j], forward_j[t, min_i, min_j]\n",
        "    min_ij.insert(0, (min_j, min_i))\n",
        "\n",
        "  del forward_cost\n",
        "  gc.collect()\n",
        "  torch.cuda.empty_cache()\n",
        "\n",
        "  return np.stack(min_ij), min_cost"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SklmS5UGz6S1"
      },
      "outputs": [],
      "source": [
        "# @title Reset the Annotated Trajectories {form-width: \"25%\"}\n",
        "\n",
        "clicks=[None for i in range(video.shape[0])]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LCguhULPz8qT"
      },
      "outputs": [],
      "source": [
        "# @title Start Annotation {form-width: \"25%\"}\n",
        "\n",
        "def mouse_position(event, frame_id):\n",
        "  x = event['clientX']\n",
        "  y = event['clientY']\n",
        "  clicks[frame_id]=[x, y]\n",
        "  print('\\r', 'Please re-run this cell ...', end='')\n",
        "\n",
        "def click2idx(click):\n",
        "  x, y = click\n",
        "  x = int(round(x))\n",
        "  y = int(round(y))\n",
        "  return x, y\n",
        "\n",
        "cur_pos = None\n",
        "frames2 = []\n",
        "all_pos = np.zeros([video.shape[0], 2], dtype=int)\n",
        "last_click = None\n",
        "for i in range(video.shape[0]):\n",
        "  if clicks[i] and last_click:\n",
        "    all_pos[last_click[0]:i+1, :], forward_cost = interpolate(optical_flows, last_click[0], last_click[1], i, clicks[i])\n",
        "\n",
        "  if clicks[i]:\n",
        "    cur_pos = copy.copy(clicks[i])\n",
        "    last_click = (i, clicks[i])\n",
        "  if cur_pos:\n",
        "    x, y = click2idx(cur_pos)\n",
        "\n",
        "    y = min(max(y, 0), height - 1)\n",
        "    x = min(max(x, 0), width - 1)\n",
        "    all_pos[i,0] = x\n",
        "    all_pos[i,1] = y\n",
        "    if i \u003c optical_flows.shape[0]:\n",
        "      cur_pos[0] += optical_flows[i, y, x, 0]\n",
        "      cur_pos[1] += optical_flows[i, y, x, 1]\n",
        "\n",
        "for i in range(video.shape[0]):\n",
        "  fr = np.copy(video[i])\n",
        "  x, y = all_pos[i] - 5\n",
        "  fr[y-2:y+3,x-2:x+3,0] = 255 if clicks[i] else 0\n",
        "  fr[y-2:y+3,x-2:x+3,1] = 0 if clicks[i] else 255\n",
        "  fr[y-2:y+3,x-2:x+3,2] = 0 if clicks[i] else 255\n",
        "  frames2.append(fr)\n",
        "\n",
        "imgs=[]\n",
        "img_ids=\"[\"\n",
        "for i in range(len(frames2)):\n",
        "  img = Img(src=frames2[i], show=i==0)\n",
        "  img.add_event_listener('click', functools.partial(mouse_position, frame_id=i))\n",
        "  imgs.append(img)\n",
        "  img_ids += \"\\\"\" + str(img._guid) + \"\\\",\"\n",
        "img_ids += \"]\"\n",
        "\n",
        "MAKO_TEMPLATE=\"\"\"\n",
        "\u003cinput type=\"range\" min=\"0\" max=\"${num_frames-1}\" value=\"0\" class=\"slider\" id=\"myRange\"\u003e\n",
        "\u003cscript\u003e\n",
        "img_ids=${img_ids}\n",
        "slider=document.getElementById(\"myRange\");\n",
        "cur_frame=0\n",
        "slider.oninput = function() {\n",
        "  idx = this.value;\n",
        "  for (var i = 0; i\u003c${num_frames}; i++){\n",
        "    document.getElementById(img_ids[i]).style.display=\"none\"\n",
        "  }\n",
        "  document.getElementById(img_ids[idx]).style.display=\"block\"\n",
        "}\n",
        "\u003c/script\u003e\n",
        "\"\"\"\n",
        "viz_tpl = template.Template(MAKO_TEMPLATE, strict_undefined=True)\n",
        "script = viz_tpl.render(num_frames=len(frames2),img_ids=img_ids)\n",
        "\n",
        "display(IPython.display.HTML(\" \".join([img._repr_html_() for img in imgs])+script))\n",
        "\n",
        "################################################################################\n",
        "# Instructions:\n",
        "#\n",
        "# 1) click anywhere on the first frame to get a point to track.\n",
        "# 2) re-run this cell to see where it goes\n",
        "# 3) click a point on any other frame, and the demo will find the shortest path.\n",
        "################################################################################"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": []
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
